import os
import sys
import re
import tqdm
import codecs
from typing import List, Dict, Union

# Helpers for reading key files and normalizing Kaldi-style data directories.


def read_file(file_path: str, encoding: str = 'utf-8') -> List[str]:
    """Read a key file and return its unique keys in sorted order.

    The file is expected to contain one key per line. Leading and trailing
    whitespace is stripped from each line, blank lines are ignored, and
    duplicated keys are collapsed into one.

    Args:
        file_path: Path of the key file to read.
        encoding: Text encoding used to decode the file.

    Returns:
        The distinct keys as a list sorted in ascending (code point) order.
    """
    # Honor the caller-supplied encoding instead of the platform default.
    with open(file_path, 'r', encoding=encoding) as f:
        keys = {line.strip() for line in f}
    # A blank line strips down to '', which is not a meaningful key.
    keys.discard('')
    return sorted(keys)


def write_kaldi_format_data(input_keys):
    """Load the keys listed in a key file.

    NOTE(review): despite the name, this function only reads the key file and
    returns its contents; it does not write a data directory.

    Args:
        input_keys (str): Path to a UTF-8 text file holding one key per line.

    Returns:
        List[str]: One entry per line of the file, in file order, with
        surrounding whitespace stripped.
    """
    # The path parameter must not be rebound: the previous implementation
    # overwrote it with an empty list before opening it, which always failed.
    with open(input_keys, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]


def get_text_keys(text_FileName):
    """Parse a whitespace-delimited key/value file into a dictionary.

    Each line is stripped and split once on the first run of whitespace: the
    first field becomes the key and the remainder of the line the value.
    Lines that do not yield two fields (blank lines, key-only lines) are
    skipped. When a key appears more than once, the last occurrence wins.

    Args:
        text_FileName: Path of the file to parse.

    Returns:
        dict: Mapping from key to the rest of its line.

    Raises:
        ValueError: If ``text_FileName`` does not exist.
    """
    if not os.path.exists(text_FileName):
        raise ValueError("{} does not exist".format(text_FileName))

    mapping = {}
    with open(text_FileName, 'r') as handle:
        for raw_line in handle:
            fields = re.split(r'\s+', raw_line.strip(), maxsplit=1)
            if len(fields) != 2:
                # Not a "key value" line; nothing to record.
                continue
            key, value = fields
            mapping[key] = value
    return mapping


def fix_data_dir(data_dir):
    """Write ``*.normalized`` tables restricted to keys shared by text and wav.scp.

    The keys present in both ``<data_dir>/text`` and ``<data_dir>/wav.scp``
    are collected, and four files are written next to the originals:
    ``text.normalized``, ``wav.scp.normalized``, ``utt2spk.normalized`` and
    ``spk2utt.normalized``. Entries are written in sorted key order so the
    output is deterministic and ordered by key.

    The existing ``utt2spk`` and ``spk2utt`` files must be present, but their
    contents are not read: both normalized files map every key to itself.

    Args:
        data_dir: Directory containing ``text``, ``wav.scp``, ``spk2utt`` and
            ``utt2spk``.

    Raises:
        ValueError: If any of the four required files is missing.
    """
    required = ('text', 'wav.scp', 'spk2utt', 'utt2spk')
    missing = [name for name in required
               if not os.path.exists(os.path.join(data_dir, name))]
    if missing:
        raise ValueError('{} is missing required file(s): {}'.format(
            data_dir, ', '.join(missing)))

    total_text_keys = get_text_keys(os.path.join(data_dir, 'text'))
    total_wav_keys = get_text_keys(os.path.join(data_dir, 'wav.scp'))

    # Sort the shared keys: iterating a raw set gives an arbitrary order that
    # can change between runs, producing unsorted, non-reproducible tables.
    unified_keys = sorted(set(total_text_keys) & set(total_wav_keys))

    def normalized_path(name):
        return os.path.join(data_dir, name + '.normalized')

    with open(normalized_path('text'), 'w') as ft, \
            open(normalized_path('wav.scp'), 'w') as fw, \
            open(normalized_path('utt2spk'), 'w') as fu, \
            open(normalized_path('spk2utt'), 'w') as fs:
        for key in unified_keys:
            ft.write('{}\t{}\n'.format(key, total_text_keys[key]))
            fw.write('{}\t{}\n'.format(key, total_wav_keys[key]))
            # Each key is written as its own speaker in both mappings.
            fu.write('{}\t{}\n'.format(key, key))
            fs.write('{}\t{}\n'.format(key, key))


if __name__ == '__main__':
    # Script entry point: the first command-line argument is the data
    # directory to normalize. Exit with a usage message instead of an
    # IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: {} <data_dir>'.format(sys.argv[0]))
    fix_data_dir(sys.argv[1])
